from tensorflow.keras.applications import VGG16, VGG19, ResNet50, InceptionV3, Xception
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras.preprocessing import image
from keras.applications import imagenet_utils
from keras.applications.inception_v3 import preprocess_input
import tensorflow as tf
from tensorflow import keras

import numpy as np
import pandas as pd
import h5py

# Accepted values of args["model"], each mapped to the Keras application
# constructor used as the convolutional base. "simple" maps to a plain string
# placeholder: create_model builds a small custom CNN for it instead.
MODELS = dict(
    vgg16=VGG16,
    vgg16_cam=VGG16,
    vgg19=VGG19,
    inception=InceptionV3,
    xception=Xception,  # TensorFlow ONLY
    resnet=ResNet50,
    simple="simple",
)

def print_arguments(args):
    """Print each argument name together with its value, one line per key.

    ``args`` is iterated for its keys and indexed with each key, in the
    mapping's own iteration order.
    """
    for name, value in ((key, args[key]) for key in args):
        print("Argument name: ", name, " Argument value:", value)

# creates the desired model architecture
def create_model(args):
    """Build and compile a binary image classifier.

    Keys read from ``args``:
      * "model": architecture name; must be a key of ``MODELS``.
      * "mode": when equal to "train", a Dropout layer is inserted before the
        output layer.
      * "dropout_rate": rate of that Dropout layer (only read in train mode).
      * "lr", "decay": passed to the RMSprop optimizer.

    For every name except "simple", the matching Keras application is built
    with ImageNet weights and ``include_top=False``, frozen, and used as the
    convolutional base. "simple" builds a small two-convolution CNN instead.

    Returns:
        A compiled ``models.Sequential`` ending in a single sigmoid unit,
        using binary cross-entropy loss and an accuracy metric.

    Raises:
        AssertionError: if ``args["model"]`` is not a key of ``MODELS``.
    """
    model_name = args["model"]
    # AssertionError (rather than ValueError) is kept so existing callers
    # that catch it keep working.
    if model_name not in MODELS:
        raise AssertionError("The --model command line argument should "
                             "be a key in the `MODELS` dictionary")

    # These architectures are fed 299x299 inputs here; all others use 224x224.
    if model_name in ("simple", "inception", "xception"):
        img_size = (299, 299, 3)
    else:
        img_size = (224, 224, 3)

    model = models.Sequential()

    if model_name == "simple":
        model.add(layers.Conv2D(32, kernel_size=5, strides=(3, 3), activation='relu',
                                input_shape=img_size))
        model.add(layers.Conv2D(64, kernel_size=3, strides=(2, 2), activation='relu'))
        model.add(layers.MaxPooling2D(pool_size=(2, 2)))
        model.add(layers.Dropout(0.25))
    else:
        network = MODELS[model_name]
        conv_base = network(weights="imagenet", include_top=False, input_shape=img_size)
        if model_name == "vgg16_cam":
            # Cut off the base's last layer (presumably the final max-pool) so
            # the head's 14x14 average pool sees the preceding feature map.
            conv_base = models.Model(conv_base.input, conv_base.layers[-2].output)
        # Freeze the pretrained base so only the classifier head is trained.
        conv_base.trainable = False
        model.add(conv_base)

    if model_name == "resnet":
        _add_classifier_head(model, args, dense_units=2048, pool_size=(7, 7))
    elif model_name == "vgg16_cam":
        _add_classifier_head(model, args, dense_units=512, pool_size=(14, 14))
    else:
        # vgg16, vgg19, inception, xception and simple share this head.
        _add_classifier_head(model, args, dense_units=256)

    # NOTE(review): newer tf.keras versions rename `lr` to `learning_rate` and
    # drop `decay`; left unchanged to match the TF1 APIs used in this file.
    model.compile(loss='binary_crossentropy',
                  optimizer=tf.keras.optimizers.RMSprop(lr=args["lr"], decay=args["decay"]),
                  metrics=['acc'])

    return model


def _add_classifier_head(model, args, dense_units, pool_size=None):
    """Append the binary-classification head to ``model``.

    Appends, in order: AveragePooling2D(pool_size) when ``pool_size`` is
    given, Flatten, Dense(dense_units, relu), Dropout(args["dropout_rate"])
    only when args["mode"] == "train", and a one-unit sigmoid Dense output.
    """
    if pool_size is not None:
        model.add(layers.AveragePooling2D(pool_size=pool_size))
    model.add(layers.Flatten())
    model.add(layers.Dense(dense_units, activation='relu'))
    if args["mode"] == "train":
        model.add(layers.Dropout(args["dropout_rate"]))
    model.add(layers.Dense(1, activation='sigmoid'))

# loads the training and validation data (the validation data is read from args["test_filename"])
def load_data(args):
    """Load the training and validation splits named in ``args``.

    The training split is read from ``args["train_filename"]`` with
    shuffling enabled. The validation split is read from
    ``args["test_filename"]`` and keeps its on-disk order.

    Returns:
        ``(train_images, train_labels, val_images, val_labels)``.
    """
    x_train, y_train = load_data_helper(hdf5_file_name=args["train_filename"], train=True)
    # NOTE(review): the test file doubles as the validation set here — confirm intended.
    x_val, y_val = load_data_helper(hdf5_file_name=args["test_filename"], train=False)
    return x_train, y_train, x_val, y_val

def calculate_accuracy(pred, label):
    """Return the fraction of predictions that exactly equal their labels.

    Both inputs are converted to NumPy arrays and flattened before comparing.
    This lets e.g. ``(N, 1)`` predictions be scored against ``(N,)`` labels
    without NumPy broadcasting them into an ``(N, N)`` comparison, which
    would count far more than N matches.

    Args:
        pred: predicted values (any array-like).
        label: ground-truth values with the same number of elements.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ValueError: if the inputs are empty or hold different numbers of
            elements.
    """
    pred = np.asarray(pred).ravel()
    label = np.asarray(label).ravel()
    if pred.size != label.size:
        raise ValueError("pred and label must have the same number of elements, "
                         "got %d and %d" % (pred.size, label.size))
    if pred.size == 0:
        raise ValueError("cannot compute accuracy of empty inputs")

    num_correct = np.count_nonzero(pred == label)
    return float(num_correct) / pred.size

def load_data_helper(hdf5_file_name, train):
    """Read the "image" and "label" datasets of an HDF5 file into memory.

    Prints the shapes and dtypes of both arrays.

    Args:
        hdf5_file_name: path of the HDF5 file, opened read-only.
        train: when True, images and labels are shuffled together with one
            random permutation (drawn from NumPy's global RNG).

    Returns:
        ``(images, labels)`` as NumPy arrays.

    Raises:
        ValueError: if ``train`` is True and the arrays differ in length
            along axis 0.
    """
    def unison_shuffled_copies(a, b):
        # Apply the same permutation to both arrays so image/label pairs stay aligned.
        # An explicit check (not `assert`) keeps the validation under `python -O`.
        if a.shape[0] != b.shape[0]:
            raise ValueError("images and labels differ in length: %d vs %d"
                             % (a.shape[0], b.shape[0]))
        p = np.random.permutation(a.shape[0])

        return a[p], b[p]

    # The context manager closes the file even if reading a dataset fails.
    with h5py.File(hdf5_file_name, 'r') as f:
        images = np.array(f["image"])
        labels = np.array(f["label"])

    if train:
        images, labels = unison_shuffled_copies(a=images, b=labels)

    print()
    print("Images shape: ", images.shape)
    print("Labels shape: ", labels.shape)
    print()

    print("Images datatype: ", images.dtype)
    print("Labels datatype: ", labels.dtype)
    print()

    return images, labels

def imgs_input_fn(images, labels, args, perform_shuffle, repeat_count, batch_size, training, input_name=''):
    """Build a batched ``(features, labels)`` tf.data pipeline from in-memory arrays.

    All images are resized to 299x299 for the "simple", "inception" and
    "xception" models and to 224x224 otherwise (``args["model"]``).
    Each image is then zero-centered with a scalar mean. When ``training``
    is true, random flips and brightness/contrast jitter are applied.
    Finally the channel axis is reversed.

    Args:
        images: array-like of images, converted with ``tf.convert_to_tensor``.
        labels: NumPy label array; 1-D labels get a trailing unit axis.
        args: dict providing "model", and "buffer_size" when shuffling.
        perform_shuffle: shuffle the dataset (only honored when ``training``).
        repeat_count: number of dataset repetitions (only when ``training``).
        batch_size: batch size.
        training: enables augmentation, repetition and shuffling.
        input_name: key under which each image is placed in the features dict.

    Returns:
        ``(batch_features, batch_labels)`` tensors from a one-shot iterator,
        where ``batch_features`` is ``{input_name: images}`` and labels are
        float32.
    """
    if args["model"] in ("simple", "inception", "xception"):
        img_size = (299, 299, 3)
    else:
        img_size = (224, 224, 3)

    images = tf.convert_to_tensor(value=images)
    images = tf.image.resize_images(images, img_size[:2])

    def _parse_function(image, label):
        # Zero-center with a single scalar mean applied to every channel.
        image = tf.subtract(image, 116.779)

        if training:
            # Data augmentation
            image = tf.image.random_flip_left_right(image)
            image = tf.image.random_flip_up_down(image)
            image = tf.image.random_brightness(image, max_delta=0.3)
            image = tf.image.random_contrast(image, lower=0.7, upper=1.3)

        image = tf.reverse(image, axis=[2])  # 'RGB'->'BGR' (assumes RGB input — TODO confirm)

        return {input_name: image}, label

    # Expand the shape of "labels" if necessary so each label has shape (1,).
    if len(labels.shape) == 1:
        labels = np.expand_dims(labels, axis=1)

    # Build the dataset for every label rank, not only for 1-D labels;
    # otherwise `dataset` would be undefined below.
    labels = tf.cast(tf.constant(labels), tf.float32)
    dataset = tf.data.Dataset.from_tensor_slices((images, labels))
    dataset = dataset.map(_parse_function)

    if training:
        dataset = dataset.repeat(repeat_count)  # Repeats dataset this # times
        if perform_shuffle:
            # Randomizes input using a window of args["buffer_size"] elements.
            # NOTE(review): shuffling after repeat() mixes examples across
            # epoch boundaries; swap the order if strict epochs are wanted.
            dataset = dataset.shuffle(buffer_size=args["buffer_size"])

    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()

    return batch_features, batch_labels
